{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to define TabularCPD and LinearGaussianCPD\n",
    "\n",
    "One can implement `TabularCPD` for discrete variables inside `DiscreteBayesianNetwork`.\n",
    "and `LinearGaussianCPD` for continuous variables inside `LinearGaussianBayesianNetwork`. \n",
    "\n",
    "In this tutorial, we will demonstrate how to define each CPD.\n",
    "\n",
    "## TabularCPD for discrete variables\n",
    "\n",
    "In tabular CPD, the probability for discrete variable is given as a table. Let us start with examples for independent variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgmpy.factors.discrete import TabularCPD\n",
    "\n",
    "cpd_coin_fair = TabularCPD(variable=\"coin\", variable_card=2, values=[[0.5], [0.5]], state_names={'Coin': ['Head', 'Tail']})\n",
    "cpd_coin_biased = TabularCPD(variable=\"coin\", variable_card=2, values=[[0.9], [0.1]], state_names={'Coin': ['Head', 'Tail']})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a coin flip, we have discrete number (2) possible outcomes, therefore `variable_card=2` is passed. `values` pass probabilities for each outcome, which sum up to 1. `state_names` is optional, it gives names for each outcome."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "cpd_smoke = TabularCPD(variable=\"Smoker\", variable_card=2, values=[[0.3], [0.7]], state_names={'Smoker': ['Non-smoker', 'Smoker']})\n",
    "cpd_pollution = TabularCPD(variable=\"Pollution\", variable_card=3, values=[[0.7], [0.29], [0.01]], state_names={'Pollution': ['Clean', 'Bad', 'Fatal']})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Pollution` and `Smoker` variables do not depend on other variables. They take categorical values. `variable_card` denotes how many categories the variable can take. For example `Smoker` can take a binary value since `variable_card=2`. `values` is an array of probability values for each category. Note that the probability `values` sum up to 1.\n",
    "\n",
    "### Tabular CPD with multiple evidence.\n",
    "\n",
    "We next consider another discrete variable `Cancer`, with a model assumption that it depends on `Pollution` and `Smoker`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cpd_cancer = TabularCPD(\n",
    "    variable=\"Cancer\",\n",
    "    variable_card=2,\n",
    "    values=[[0.20, 0.15, 0.03, 0.05, 0.001, 0.02],\n",
    "            [0.80, 0.85, 0.97, 0.95, 0.999, 0.98]],\n",
    "    evidence=[\"Smoker\", \"Pollution\"],\n",
    "    evidence_card=[2, 3],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For `Cancer` variable, we pass another `TabularCPD`. `evidence_card` denotes cardinality of each evidence variable. There are total `2*3=6` different combinations of evidence. Note the values of `evidence` affects probability of `Cancer`. We have 6 columns in `values` denoting conditional probabilities. `values` sums up to 1 columnwise. The columns are partitioned by the first evidence first. The first three columns correspond to non-smokers.\n",
    "\n",
    "We next consider another discrete variable `D`, with a model assumption that it depends on `A`, `B` and `C`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "cpd = TabularCPD(\n",
    "    variable=\"D\",\n",
    "    variable_card=2,\n",
    "    values=[[0.20, 0.15, 0.93, 0.05, 0.001, 0.02, 0.10, 0.25 ],\n",
    "            [0.80, 0.85, 0.07, 0.95, 0.999, 0.98, 0.90, 0.75 ]],\n",
    "    evidence=[\"A\", \"B\", \"C\"],\n",
    "    evidence_card=[2, 2, 2],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+------+------+------+------+-------+------+------+------+\n",
      "| A    | A(0) | A(0) | A(0) | A(0) | A(1)  | A(1) | A(1) | A(1) |\n",
      "+------+------+------+------+------+-------+------+------+------+\n",
      "| B    | B(0) | B(0) | B(1) | B(1) | B(0)  | B(0) | B(1) | B(1) |\n",
      "+------+------+------+------+------+-------+------+------+------+\n",
      "| C    | C(0) | C(1) | C(0) | C(1) | C(0)  | C(1) | C(0) | C(1) |\n",
      "+------+------+------+------+------+-------+------+------+------+\n",
      "| D(0) | 0.2  | 0.15 | 0.93 | 0.05 | 0.001 | 0.02 | 0.1  | 0.25 |\n",
      "+------+------+------+------+------+-------+------+------+------+\n",
      "| D(1) | 0.8  | 0.85 | 0.07 | 0.95 | 0.999 | 0.98 | 0.9  | 0.75 |\n",
      "+------+------+------+------+------+-------+------+------+------+\n"
     ]
    }
   ],
   "source": [
    "print(cpd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each column of `values` correspond to a different combinations of `evidence` variables. If the evidence variables were lexicographic as above `(A, B, C)`, then the columns will be also lexicographic. \n",
    "\n",
    "```(A(0)B(0)C(0), A(0)B(0)C(1), A(0)B(1)C(0), A(0)B(1)C(1), ...)```\n",
    "\n",
    "It is as though we loop over A, and B, and C, in a nested order. \n",
    "\n",
    "### CPD with random values\n",
    "It is also possible to create a random tabular CPD."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.55203474],\n",
       "       [0.44796526]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cpd_coin_random = TabularCPD.get_random(variable=\"coin\", cardinality = {'coin':2}, state_names={'Coin': ['Head', 'Tail']})\n",
    "cpd_coin_random.get_values()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the probabilities can be retrieved by `get_values` and they add up to 1 again. If `cardinality` is missing, the discrete variable is assumed to be binary. \n",
    "\n",
    "When there are `evidence`, the cardinality of evidence needs to be also specified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.30404579, 0.99344749, 0.47933983, 0.58564998, 0.54269003,\n",
       "        0.28164646],\n",
       "       [0.69595421, 0.00655251, 0.52066017, 0.41435002, 0.45730997,\n",
       "        0.71835354]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cpd_cancer_random = TabularCPD.get_random(\n",
    "    variable=\"Cancer\",\n",
    "    cardinality = {\"Cancer\": 2, \"Smoker\": 2, \"Pollution\": 3},\n",
    "    evidence=[\"Smoker\", \"Pollution\"],\n",
    ")\n",
    "cpd_cancer_random.get_values()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1., 1., 1., 1.])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cpd_cancer_random.get_values().sum(axis = 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note again the probability table sums up to 1 for each column.\n",
    "\n",
    "## LinearGaussianCPD for continuous variables\n",
    "\n",
    "We define CPDs for each variable. `LinearGaussianCPD` assumes that each variable takes normal distribution. \n",
    "Assume `Healthy` and `Wealthy` variables have no parents (empty evidence). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<LinearGaussianCPD: P(Healthy) = N(4; 2) at 0x15735ed20\n",
      "<LinearGaussianCPD: P(Wealthy) = N(2; 3) at 0x1566c2930\n"
     ]
    }
   ],
   "source": [
    "# Step 2: Define the CPDs.\n",
    "from pgmpy.factors.continuous import LinearGaussianCPD\n",
    "cpd_healthy = LinearGaussianCPD(variable=\"Healthy\", beta=[4], std=2, evidence=[])\n",
    "cpd_wealthy = LinearGaussianCPD(variable=\"Wealthy\", beta=[2], std=3, evidence=[])\n",
    "\n",
    "import pprint\n",
    "pprint.pp(cpd_healthy)\n",
    "pprint.pp(cpd_wealthy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above we defined, Healthy as normal distribution with mean 4 and standard deviation 2, and wealthy with mean 2, and standard deviation 3. \n",
    "(Assume bigger variation in people's wealth than health, bigger mean for health than wealth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<LinearGaussianCPD: P(Happy | Healthy, Wealthy) = N(3*Healthy + 2*Wealthy + 1; 5) at 0x154160050\n"
     ]
    }
   ],
   "source": [
    "cpd_happy = LinearGaussianCPD(\n",
    "    variable=\"Happy\",\n",
    "    beta=[1, 3, 2],\n",
    "    std=5,\n",
    "    evidence=[\"Healthy\", \"Wealthy\"],\n",
    ")\n",
    "pprint.pp(cpd_happy)"
   ]
  },
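  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the printed formula concrete, here is a plain-Python sketch (not pgmpy's implementation; the parent values are made up for illustration) that evaluates the conditional mean from `beta` and draws one sample with the standard library's `random.gauss`:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "beta = [1, 3, 2]  # intercept, then coefficients for Healthy and Wealthy\n",
    "std = 5\n",
    "evidence = ['Healthy', 'Wealthy']\n",
    "parents = {'Healthy': 4.0, 'Wealthy': 2.0}  # illustrative values, not from data\n",
    "\n",
    "# Conditional mean: intercept plus coefficient * parent value for each parent.\n",
    "mean = beta[0] + sum(b * parents[name] for b, name in zip(beta[1:], evidence))\n",
    "print(mean)  # 1 + 3*4 + 2*2 = 17.0\n",
    "\n",
    "# One draw of Happy given these parent values.\n",
    "sample = random.gauss(mean, std)\n",
    "print(sample)\n",
    "```"
   ]
  },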
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Happy` variable has mean `3*Healthy + 2*Wealthy + 1`, this formula is determined by passing `beta` and `evidence` variables. Note that the first element of `beta` is the intercept (constant term). The rest of the elements in `beta` each match the evidence.\n",
    "The standard deviation of normal distribution of `Happy` is set by `std`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pgmpy",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
